import torch
import torch.nn as nn

class SR_CNN(nn.Module):
    """Convolutional network mapping a 1-channel image batch to a 4x larger one.

    Pipeline:
      1. two 3x3 conv + BatchNorm + ReLU layers at the input resolution
         (padding=1 keeps H and W unchanged),
      2. concatenation of one auxiliary noise channel built from ``eta``,
      3. two kernel-2 / stride-2 transposed convolutions, each doubling H and W,
      4. a 3x3 conv down to one channel followed by a sigmoid, so outputs
         lie in [0, 1].
    """

    def __init__(self) -> None:
        super().__init__()

        features = 64        # channels produced by the feature extractor
        noise_channels = 1   # channels contributed by the auxiliary input ``eta``

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
        )

        # Input width is derived from the concatenation done in ``forward``
        # (features + noise channel) instead of a hard-coded 65.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(features + noise_channels, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.out_conv = nn.Sequential(
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _as_noise_map(eta: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Return ``eta`` as an (N, 1, H, W) map matching ``like``'s N, H, W, dtype and device.

        Accepted shapes of ``eta``:
          * ``()``           -- one value shared by the whole batch,
          * ``(N,)``         -- one value per sample,
          * ``(N, H, W)``    -- a per-pixel map without a channel axis,
          * ``(N, 1, H, W)`` -- a per-pixel map; singleton batch or spatial
            dims (e.g. ``(N, 1, 1, 1)`` or ``(1, 1, H, W)``) are broadcast.
        Any other shape raises a ``RuntimeError`` from ``expand``/``cat``.
        """
        if eta.dim() == 0:
            eta = eta.reshape(1, 1, 1, 1)
        elif eta.dim() == 1:
            eta = eta.reshape(-1, 1, 1, 1)
        elif eta.dim() == 3:
            eta = eta.unsqueeze(1)
        batch, _, height, width = like.shape
        # ``.to`` is a no-op when dtype/device already match; ``expand`` only
        # broadcasts size-1 dims (no copy), so full-size maps are unchanged.
        return eta.to(device=like.device, dtype=like.dtype).expand(batch, -1, height, width)

    def forward(self, x: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by 4x, conditioned on the auxiliary noise input ``eta``.

        Args:
            x: image batch of shape (N, 1, H, W). BatchNorm2d requires 4-D
                input and the first conv takes 1 channel.
            eta: auxiliary noise input; a scalar, per-sample values, or a
                per-pixel map at the input resolution (see ``_as_noise_map``).

        Returns:
            Tensor of shape (N, 1, 4H, 4W) with values in [0, 1].
        """
        features = self.feature_extractor(x)   # (N, 64, H, W)
        noise_map = self._as_noise_map(eta, features)

        # Concatenate the auxiliary noise as one extra channel.
        fused = torch.cat((features, noise_map), dim=1)

        upsampled = self.up2(self.up1(fused))  # (N, 16, 4H, 4W)
        return self.out_conv(upsampled)
